import time
import numpy as np
from marl.train import marl_training


def run_marl(dataset: str, episodes: int = 100, eval_timeout: int = 300) -> dict:
    """Run one MARL training session and summarise its pipeline statistics.

    Args:
        dataset: Dataset name, forwarded to ``marl_training`` as ``dataset_name``.
        episodes: Number of training episodes, forwarded to ``marl_training``.
        eval_timeout: Evaluation timeout forwarded to ``marl_training``
            (presumably seconds — confirm against ``marl_training``).

    Returns:
        A dict containing the run configuration (``dataset``, ``episodes``,
        ``eval_timeout``), the ``success_rate`` (as ``best_success_rate``,
        default 0.0) and ``total_pipelines`` (default 0) entries read from
        ``env.get_pipeline_statistics()``, and ``time_sec``: the elapsed time
        in seconds for training plus statistics retrieval.
    """
    # perf_counter is monotonic and high-resolution. A difference of time.time()
    # readings can be skewed, or even go negative, if the system clock is
    # adjusted while training runs.
    start = time.perf_counter()
    env = marl_training(dataset_name=dataset, episodes=episodes, eval_timeout=eval_timeout)
    stats = env.get_pipeline_statistics()
    duration = time.perf_counter() - start
    return {
        "dataset": dataset,
        "episodes": episodes,
        "eval_timeout": eval_timeout,
        # NOTE(review): the key says "best" but the value is stats["success_rate"];
        # confirm that statistic really is the best rate, not a final or mean one.
        "best_success_rate": float(stats.get("success_rate", 0.0)),
        "total_pipelines": int(stats.get("total_pipelines", 0)),
        "time_sec": duration,
    }


if __name__ == "__main__":
    # Command-line entry point: read the experiment options, run a single MARL
    # session via run_marl, and print the resulting summary dict to stdout.
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="iris")
    parser.add_argument("--episodes", type=int, default=100)
    parser.add_argument("--eval-timeout", type=int, default=300)
    cli = parser.parse_args()

    summary = run_marl(
        dataset=cli.dataset,
        episodes=cli.episodes,
        eval_timeout=cli.eval_timeout,
    )
    print(summary)
